{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Readme"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- This example shows how to fine-tune BERT (from https://github.com/huggingface/transformers) for sequence tagging on the CoNLL-2003 corpus.\n",
    "- With this notebook it is possible to achieve about 91.4 entity-level F1 and 92.8 token-level F1 with bert-base-cased; exact numbers vary slightly between runs (the run recorded below reached 91.5 / 93.0).\n",
    "- With 100 epochs, lr=1e-5 and batch_size=8 it is possible to reach the same score as reported in https://gluon-nlp.mxnet.io/model_zoo/bert/index.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel\n",
    "%pip install pytorch_transformers flair seqeval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download CoNLL-2003"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the three CoNLL-2003 files (train, testa, testb) into ./conll2003\n",
    "!mkdir -p conll2003\n",
    "!wget --output-document=./conll2003/eng.train https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train\n",
    "!wget --output-document=./conll2003/eng.testa https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa\n",
    "!wget --output-document=./conll2003/eng.testb https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only GPU 0 to this process; change the value to pick another device\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES='0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "# Route INFO-and-above log records to stdout so they appear as regular cell output\n",
    "logging.basicConfig(level=logging.INFO, stream=sys.stdout)\n",
    "\n",
    "# Logger used by this notebook for its own messages\n",
    "logger = logging.getLogger('sequence_tagger_bert')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tesla V100-DGXS-16GB\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device('cuda')\n",
    "n_gpu = torch.cuda.device_count()\n",
    "\n",
    "# Print the name of every GPU visible to this process\n",
    "for gpu_index in range(n_gpu):\n",
    "    print(torch.cuda.get_device_name(gpu_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory passed as cache_dir when loading the pretrained tokenizer and model\n",
    "CACHE_DIR = 'cache'\n",
    "\n",
    "# Sequence length limit handed to the tagger as max_len\n",
    "MAX_LEN = 128\n",
    "\n",
    "# Batch sizes for training and for prediction\n",
    "BATCH_SIZE = 16\n",
    "PRED_BATCH_SIZE = 100\n",
    "\n",
    "# Optimisation settings\n",
    "MAX_N_EPOCHS = 4\n",
    "LEARNING_RATE = 5e-5\n",
    "WEIGHT_DECAY = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fd95505e530>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Fix the PyTorch RNG seed so that repeated runs start from the same random state\n",
    "SEED = 117\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:transformers.file_utils:TensorFlow version 2.0.0 available.\n",
      "INFO:transformers.file_utils:PyTorch version 1.3.1 available.\n",
      "INFO:transformers.modeling_xlnet:Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n",
      "2019-12-13 00:35:35,712 Reading data from conll2003\n",
      "2019-12-13 00:35:35,713 Train: conll2003/eng.train\n",
      "2019-12-13 00:35:35,713 Dev: conll2003/eng.testa\n",
      "2019-12-13 00:35:35,714 Test: conll2003/eng.testb\n",
      "{\n",
      "    \"TRAIN\": {\n",
      "        \"dataset\": \"TRAIN\",\n",
      "        \"total_number_of_documents\": 14987,\n",
      "        \"number_of_documents_per_class\": {},\n",
      "        \"number_of_tokens_per_tag\": {},\n",
      "        \"number_of_tokens\": {\n",
      "            \"total\": 204567,\n",
      "            \"min\": 1,\n",
      "            \"max\": 113,\n",
      "            \"avg\": 13.649629679055181\n",
      "        }\n",
      "    },\n",
      "    \"TEST\": {\n",
      "        \"dataset\": \"TEST\",\n",
      "        \"total_number_of_documents\": 3684,\n",
      "        \"number_of_documents_per_class\": {},\n",
      "        \"number_of_tokens_per_tag\": {},\n",
      "        \"number_of_tokens\": {\n",
      "            \"total\": 46666,\n",
      "            \"min\": 1,\n",
      "            \"max\": 124,\n",
      "            \"avg\": 12.667209554831704\n",
      "        }\n",
      "    },\n",
      "    \"DEV\": {\n",
      "        \"dataset\": \"DEV\",\n",
      "        \"total_number_of_documents\": 3466,\n",
      "        \"number_of_documents_per_class\": {},\n",
      "        \"number_of_tokens_per_tag\": {},\n",
      "        \"number_of_tokens\": {\n",
      "            \"total\": 51578,\n",
      "            \"min\": 1,\n",
      "            \"max\": 109,\n",
      "            \"avg\": 14.881130986728216\n",
      "        }\n",
      "    }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "from flair.datasets import ColumnCorpus\n",
    "\n",
    "\n",
    "data_folder = 'conll2003'\n",
    "\n",
    "# Column 0 of each file is read as the token text, column 3 as the NER tag\n",
    "column_format = {0: 'text', 3: 'ner'}\n",
    "\n",
    "# testa serves as the dev split, testb as the test split\n",
    "corpus = ColumnCorpus(data_folder,\n",
    "                      column_format,\n",
    "                      train_file='eng.train',\n",
    "                      test_file='eng.testb',\n",
    "                      dev_file='eng.testa')\n",
    "\n",
    "print(corpus.obtain_statistics())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_transformers.modeling_bert:Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n",
      "INFO:pytorch_transformers.modeling_xlnet:Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n",
      "INFO:pytorch_transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at cache/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n",
      "INFO:pytorch_transformers.modeling_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at cache/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.d7a3af18ce3a2ab7c0f48f04dc8daff45ed9a3ed333b9e9a79d012a0dedf87a6\n",
      "INFO:pytorch_transformers.modeling_utils:Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"num_labels\": 9,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 28996\n",
      "}\n",
      "\n",
      "INFO:pytorch_transformers.modeling_utils:loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin from cache at cache/35d8b9d36faaf46728a0192d82bf7d00137490cd6074e8500778afed552a67e5.3fadbea36527ae472139fe84cddaa65454d7429f12d543d80bfc3ad70de55ac2\n",
      "INFO:pytorch_transformers.modeling_utils:Weights of BertForTokenClassificationCustom not initialized from pretrained model: ['classifier.weight', 'classifier.bias']\n",
      "INFO:pytorch_transformers.modeling_utils:Weights from pretrained model not used in BertForTokenClassificationCustom: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n"
     ]
    }
   ],
   "source": [
    "from bert_sequence_tagger import SequenceTaggerBert, BertForTokenClassificationCustom\n",
    "from bert_sequence_tagger.bert_utils import make_bert_tag_dict_from_flair_corpus\n",
    "\n",
    "from pytorch_transformers import BertTokenizer, BertForTokenClassification\n",
    "\n",
    "\n",
    "# The same pretrained checkpoint is used for the tokenizer and for the model\n",
    "PRETRAINED_MODEL = 'bert-base-cased'\n",
    "\n",
    "# Cased checkpoint, so the tokenizer is told not to lower-case its input\n",
    "bpe_tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL,\n",
    "                                              cache_dir=CACHE_DIR,\n",
    "                                              do_lower_case=False)\n",
    "\n",
    "# Tag dictionaries derived from the corpus (presumably index->tag and tag->index)\n",
    "idx2tag, tag2idx = make_bert_tag_dict_from_flair_corpus(corpus)\n",
    "\n",
    "# The classifier gets one label per entry of tag2idx; the model is then moved to the GPU\n",
    "model = BertForTokenClassificationCustom.from_pretrained(PRETRAINED_MODEL,\n",
    "                                                         cache_dir=CACHE_DIR,\n",
    "                                                         num_labels=len(tag2idx))\n",
    "model = model.cuda()\n",
    "\n",
    "seq_tagger = SequenceTaggerBert(bert_model=model,\n",
    "                                bpe_tokenizer=bpe_tokenizer,\n",
    "                                idx2tag=idx2tag,\n",
    "                                tag2idx=tag2idx,\n",
    "                                max_len=MAX_LEN,\n",
    "                                pred_batch_size=PRED_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   0%|          | 0/4 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:sequence_tagger_bert:Current learning rate: 3.828420055249356e-05\n",
      "INFO:sequence_tagger_bert:Train loss: 0.08635571034180195\n",
      "INFO:sequence_tagger_bert:Validation loss: 0.03886955951792047\n",
      "INFO:sequence_tagger_bert:Validation metrics: (0.9362345627152818,)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  25%|██▌       | 1/4 [02:33<07:41, 153.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:sequence_tagger_bert:Current learning rate: 2.6567066579477667e-05\n",
      "INFO:sequence_tagger_bert:Train loss: 0.022030211970791358\n",
      "INFO:sequence_tagger_bert:Validation loss: 0.032396575959865004\n",
      "INFO:sequence_tagger_bert:Validation metrics: (0.9428475374012439,)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  50%|█████     | 2/4 [05:02<05:04, 152.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:sequence_tagger_bert:Current learning rate: 1.4849932606461773e-05\n",
      "INFO:sequence_tagger_bert:Train loss: 0.009602448526282073\n",
      "INFO:sequence_tagger_bert:Validation loss: 0.031968948459534935\n",
      "INFO:sequence_tagger_bert:Validation metrics: (0.9522608841822156,)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:  75%|███████▌  | 3/4 [07:31<02:31, 151.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:sequence_tagger_bert:Current learning rate: 3.1327986334458785e-06\n",
      "INFO:sequence_tagger_bert:Train loss: 0.003817373911283591\n",
      "INFO:sequence_tagger_bert:Validation loss: 0.03424496421454629\n",
      "INFO:sequence_tagger_bert:Validation metrics: (0.9553616378587012,)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 100%|██████████| 4/4 [10:05<00:00, 151.47s/it]\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "from bert_sequence_tagger.bert_utils import get_model_parameters, prepare_flair_corpus\n",
    "from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert\n",
    "from bert_sequence_tagger.metrics import f1_entity_level, f1_token_level\n",
    "\n",
    "from pytorch_transformers import AdamW, WarmupLinearSchedule\n",
    "\n",
    "\n",
    "# Fraction of all optimizer steps spent on linear learning-rate warm-up\n",
    "WARMUP_PROPORTION = 0.1\n",
    "\n",
    "train_dataset = prepare_flair_corpus(corpus.train)\n",
    "val_dataset = prepare_flair_corpus(corpus.dev)\n",
    "\n",
    "optimizer = AdamW(get_model_parameters(model),\n",
    "                  lr=LEARNING_RATE, betas=(0.9, 0.999),\n",
    "                  eps=1e-6, weight_decay=WEIGHT_DECAY, correct_bias=True)\n",
    "\n",
    "# WarmupLinearSchedule takes warmup_steps and t_total as numbers of optimizer steps,\n",
    "# so the warm-up proportion is converted into a step count here.\n",
    "# NOTE(review): the recorded learning-rate log ends above zero, which suggests that fewer\n",
    "# steps run per epoch than len(corpus.train) / BATCH_SIZE; consider deriving the step\n",
    "# count from the prepared training set instead - TODO confirm.\n",
    "n_train_steps = math.ceil(len(corpus.train) / BATCH_SIZE) * MAX_N_EPOCHS\n",
    "n_warmup_steps = int(WARMUP_PROPORTION * n_train_steps)\n",
    "lr_scheduler = WarmupLinearSchedule(optimizer,\n",
    "                                    warmup_steps=n_warmup_steps,\n",
    "                                    t_total=n_train_steps)\n",
    "\n",
    "trainer = ModelTrainerBert(model=seq_tagger,\n",
    "                           optimizer=optimizer,\n",
    "                           lr_scheduler=lr_scheduler,\n",
    "                           train_dataset=train_dataset,\n",
    "                           val_dataset=val_dataset,\n",
    "                           update_scheduler='es',\n",
    "                           validation_metrics=[f1_entity_level],\n",
    "                           batch_size=BATCH_SIZE)\n",
    "\n",
    "trainer.train(epochs=MAX_N_EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:sequence_tagger_bert:Entity-level f1: 0.9146330719760374\n",
      "INFO:sequence_tagger_bert:Token-level f1: 0.9302211603259204\n"
     ]
    }
   ],
   "source": [
    "test_dataset = prepare_flair_corpus(corpus.test)\n",
    "\n",
    "# With evaluate=True the call yields three values; only the metrics are used here\n",
    "_ignored_0, _ignored_1, test_metrics = seq_tagger.predict(\n",
    "    test_dataset, evaluate=True, metrics=[f1_entity_level, f1_token_level])\n",
    "\n",
    "# Positions 1 and 2 follow the order of the metrics list passed above\n",
    "entity_f1, token_f1 = test_metrics[1], test_metrics[2]\n",
    "logger.info(f'Entity-level f1: {entity_f1}')\n",
    "logger.info(f'Token-level f1: {token_f1}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([['O', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'O', 'O'],\n",
       "  ['I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O']],\n",
       " [10.285023, 10.41528])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each sentence is given as a list of already split tokens\n",
    "example_sentences = [\n",
    "    ['We', 'are', 'living', 'in', 'New', 'York', 'city', '.'],\n",
    "    ['Satya', 'Narayana', 'Nadella', 'is', 'an', 'engineer', 'and', 'business', 'executive', '.'],\n",
    "]\n",
    "seq_tagger.predict(example_sentences)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the fine-tuned tagger to disk so that it can be restored later\n",
    "MODEL_DIR = './model'\n",
    "seq_tagger.save_serialize(MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Restore model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only GPU 0 to this process; change the value to pick another device\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES='0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device('cuda')\n",
    "n_gpu = torch.cuda.device_count()\n",
    "\n",
    "# Print the name of every GPU visible to this process\n",
    "for gpu_index in range(n_gpu):\n",
    "    print(torch.cuda.get_device_name(gpu_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert_sequence_tagger import SequenceTaggerBert, BertForTokenClassificationCustom\n",
    "\n",
    "\n",
    "# Rebuild the tagger from the directory written in the Save model section\n",
    "MODEL_DIR = './model'\n",
    "seq_tagger = SequenceTaggerBert.load_serialized(MODEL_DIR, BertForTokenClassificationCustom)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each sentence is given as a list of already split tokens\n",
    "example_sentences = [\n",
    "    ['We', 'are', 'living', 'in', 'New', 'York', 'city', '.'],\n",
    "    ['Satya', 'Narayana', 'Nadella', 'is', 'an', 'engineer', 'and', 'business', 'executive', '.'],\n",
    "]\n",
    "seq_tagger.predict(example_sentences)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
